#!/usr/bin/python
#-*-coding:utf-8-*-

###############################################################
## Name       : gen_link
## Author     : xiaotu
## Time       : 2021-07-10 10:12:48
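## Description: Expand the GEN_LINK / GEN_PARA / GEN_PORT / GEN_INPUT /
##              GEN_OUTPUT / GEN_WIRE placeholders of a Verilog file (-f)
##              and print the result; -d only strips previously generated code.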
###############################################################

import sys
import os
import re
import argparse
import math

class Inst:#{{{
    '''
    One instantiation of a Module.

    A module may be instantiated several times, so all per-instance
    information (link rules, parameter overrides, generated text) is kept
    here instead of on the Module.
    '''
    def __init__(self, name, module):
        self.name   = name
        self.module = module
        # ".port(link)" rules from the LINK_MODULE comment: port -> link
        self.port_link = {}
        # rules whose port pattern matches a signal of the module
        self.real_port_link = {}
        # remaining rules, used as parameter overrides (values stored as str)
        self.real_para_link = {}
        # generated text lines, filled in by later passes
        self.port_list = []
        self.para_list = []
        self.inst_with_para    = []
        self.inst_without_para = []
    def get_port_inst(self, port, link):
        '''Record the rule .port(link); surrounding whitespace of link is dropped.'''
        self.port_link[port] = link.strip()
    def get_module(self, module):
        '''Re-point this instance at module.'''
        self.module = module
    def split_port_para(self):
        '''
        Sort port_link into real_port_link (the port pattern matches one of
        the module's signals, per Module.judge_sig) and real_para_link
        (no match; treated as a parameter override).
        '''
        for org in self.port_link.keys():
            if self.module.judge_sig(org) == 1:
                self.real_port_link[org] = self.port_link[org]
            else:
                self.real_para_link[org] = str(self.port_link[org])
    # The setters below used to default to a shared mutable []: calling one
    # without an argument gave every Inst the same list object, so appending
    # to one instance's list silently changed the others.  The parameter
    # keeps its old name "list" so keyword callers still work.
    def get_port_list(self, list = None):
        '''Set port_list (a fresh empty list when no list is given).'''
        self.port_list = [] if list is None else list
    def get_para_list(self, list = None):
        '''Set para_list (a fresh empty list when no list is given).'''
        self.para_list = [] if list is None else list
    def get_inst_with_para(self, list = None):
        '''Set inst_with_para (a fresh empty list when no list is given).'''
        self.inst_with_para = [] if list is None else list
    def get_inst_without_para(self, list = None):
        '''Set inst_without_para (a fresh empty list when no list is given).'''
        self.inst_without_para = [] if list is None else list
#}}}

class Module:#{{{
    '''
    A Verilog module definition: the file it lives in and its signals.
    '''
    def __init__(self, name):
        self.name = name
        self.path = ""          # RTL file declaring the module; "" until found
        self.sig_hash = {}      # signal name -> Signal, every parsed declaration
        self.input_hash = {}    # the sig_hash entries whose port is "input"
        self.output_hash = {}   # the sig_hash entries whose port is "output"
    def get_rtl_file(self, path):
        '''Remember the module's RTL file path, trimmed of whitespace.'''
        self.path = path.strip()
    def get_sig(self, sig):
        '''Register sig under sig.name (a later same-named sig replaces it).'''
        self.sig_hash[sig.name] = sig
    def split_port(self):
        '''Copy the input/output signals of sig_hash into input_hash/output_hash.'''
        for sig_name, sig_obj in self.sig_hash.items():
            if sig_obj.port == "input":
                self.input_hash[sig_name] = sig_obj
            if sig_obj.port == "output":
                self.output_hash[sig_name] = sig_obj
    def judge_sig(self, sig):
        '''Return 1 when regex sig is found (re.search) in any signal name, else 0.'''
        hit = any(re.search(sig, sig_name) for sig_name in self.sig_hash)
        return 1 if hit else 0
#}}}

class Signal:#{{{
    '''
    One declared signal (name, net type, port direction, width text).

    The class attributes record the longest name / width string seen over
    every Signal created so far.
    '''
    # "widh_width" keeps its original spelling: it is the attribute's public name.
    name_width = 1
    widh_width = 1
    def __init__(self, name, type="wire", port="none", width = ""):
        self.name, self.type, self.port, self.width = name, type, port, width
        Signal.name_width = max(Signal.name_width, len(name))
        Signal.widh_width = max(Signal.widh_width, len(width))

    def __str__(self):
        rows = [
            "Class Signal",
            "    name   : %s" % self.name,
            "    type   : %s" % self.type,
            "    port   : %s" % self.port,
            "    width  : %s" % self.width,
            "    max name_width = %s" % Signal.name_width,
        ]
        return "\n".join(rows) + "\n"
#}}}

def exit(note = "", code = 0):#{{{
    '''
    Print note and terminate the script with exit status code.

    The old body called os.exit(), which does not exist, so every call
    died with an AttributeError traceback instead of exiting cleanly;
    sys.exit() is the real API.  code defaults to 0, the status the
    original code intended.
    '''
    print(note)
    sys.exit(code)
pass#}}}

def debug_print(in_str):#{{{
    '''Print in_str, but only when the module-level debug flag (-debug) is set.'''
    if not debug:
        return
    print(in_str)
pass#}}}

def input_args_proc():#{{{
    '''
    Parse the command line and return the lines of the input RTL file.

    Options: -f <rtl.v> input file, -d delete generated code only,
    -debug debug mode, -o open this script in gvim and exit.
    Sets the globals rtl_file, del_flag and debug.  A missing -f is
    reported as a usage error (exit status 2) instead of crashing with a
    TypeError from open(None).

    Returns the list of lines (line endings kept) of rtl_file.
    '''
    global rtl_file
    global del_flag
    global debug
    parser = argparse.ArgumentParser(description="argparse info")
    parser.add_argument('-o', action='store_true', default=False, help='open this script')
    parser.add_argument('-f', help='input rtl.v')
    parser.add_argument('-d', action='store_true', default=False, help='delete gen_link code')
    parser.add_argument('-debug', action='store_true', default=False, help='debug')
    result = parser.parse_args()
    if result.o:
        os.system("gvim %s" % __file__)
        sys.exit(0)
    # -f cannot be declared required=True: -o must keep working without it.
    if result.f is None:
        parser.error("argument -f is required")
    rtl_file = result.f
    del_flag = result.d
    debug = result.debug
    with open(rtl_file, "r") as rtl:
        return rtl.readlines()
pass#}}}

def match_rep(org, to, inst):#{{{
    '''
    Rewrite the name inst with the regex rule org -> to.

    org is a regular expression that must match the WHOLE of inst.  In to,
    "#N#" or "#N" is replaced by the text of the N-th capture group of org
    (an unmatched optional group counts as "").  If org does not match,
    inst is returned unchanged.

    Example: match_rep("(.*)_i", "u0_#1#_o", "data_i") -> "u0_data_o"
    '''
    # (?:...) keeps a top-level "a|b" inside the anchors (the old "^a|b$"
    # matched any name ending in b).  It is non-capturing, so the group
    # numbers used in `to` do not shift.
    matched = re.match(r"(?:%s)$" % org, inst)
    if not matched:
        return inst
    rep = to
    for i, key in enumerate(matched.groups(), 1):
        # Optional groups that did not take part are None (the old re.sub
        # call raised TypeError on them).
        key = key or ""
        # Literal replacement: a group value containing "\" must not be
        # interpreted as a regex replacement escape.
        rep = rep.replace("#%d#" % i, key).replace("#%d" % i, key)
    return rep
pass#}}}

def head_tail_split(handle, head = "\n", tail = "\n", mode = 0):#{{{
    '''
    Return the lines of handle that lie between a head line and a tail line.

    head and tail are regexes tested with re.search on every line.  A region
    opens on a head line and closes on the next tail line; several regions
    may occur.  mode selects whether the boundary lines are kept:
      0 - neither the head nor the tail line
      1 - both the head and the tail line
      2 - the head line but not the tail line
      3 - the tail line but not the head line
    A line matching both regexes is a one-line region (head and tail) when
    outside a region, and only closes the region when inside one.
    '''
    ret = []
    shot_en = 0                     # 1 while inside a region
    for line in handle:
        is_head = re.search(head, line)
        is_tail = re.search(tail, line)
        # Boundary flags describe the current line only.  The old code
        # carried them over from earlier lines, so e.g. mode 3 dropped a tail
        # that directly followed its head, and mode 2 dropped a head that
        # directly followed a tail.
        head_en = 0
        tail_en = 0
        if is_head and is_tail:
            # Must be tested BEFORE shot_en is cleared; the old code cleared
            # it first, so this line was always treated as a head as well.
            if shot_en == 0:
                head_en = 1
            tail_en = 1
            shot_en = 0
        elif is_head:
            if shot_en == 0:
                head_en = 1
            shot_en = 1
        elif is_tail:
            if shot_en == 1:
                tail_en = 1
            shot_en = 0

        if mode == 0:
            keep = shot_en == 1 and head_en == 0 and tail_en == 0
        elif mode == 1:
            keep = shot_en == 1 or head_en == 1 or tail_en == 1
        elif mode == 2:
            keep = (shot_en == 1 or head_en == 1) and tail_en == 0
        else:
            keep = (shot_en == 1 or tail_en == 1) and head_en == 0
        if keep:
            ret.append(line)
    return ret
pass#}}}

def get_src_rtl(rtl_handle):#{{{
    '''
    Collect the directories that are searched for sub-module RTL files.

    The result always starts with "./".  It is followed by every
    double-quoted string found (one per comma-separated piece) on the lines
    between the "SUB MODULE LIST START" and "SUB MODULE LIST END" marker
    lines.  Explicit "./" entries are skipped because "./" is already first.
    '''
    src_rtl_list = [r"./"]
    listed = head_tail_split(rtl_handle, "SUB MODULE LIST START", "SUB MODULE LIST END")
    for line0 in listed:
        for line1 in line0.split(","):
            found = re.search("\"(.*)\"", line1)
            # Pieces without a quoted path (trailing comma, comment text)
            # used to crash with AttributeError on None.groups().
            if found is None:
                continue
            fl = found.group(1)
            if fl != "./":
                src_rtl_list.append(fl)
    return src_rtl_list
pass#}}}

def link_module(rtl_handle):#{{{
    '''
    Collect the LINK_MODULE comments of the RTL.

    A comment opening with "/* <module> <inst> LINK_MODULE" starts an
    instance description.  Every ".port(link)" line after that opener, up to
    and including the line holding "*/", is recorded as a port rule of that
    Inst.

    Returns (inst_hash, module_hash): instance name -> Inst and
    module name -> Module.
    '''
    opener_re = re.compile(r"/\*\s*(\w+)\s+(\w+)\s+LINK_MODULE")
    closer_re = re.compile(r"\*/")
    rule_re   = re.compile(r"\.(.*)\s*\((.*)\)")
    inst_hash = {}
    module_hash = {}
    inside = False          # True between an opener and its "*/"
    cur_inst = ""           # instance currently being described
    for line in rtl_handle:
        # Rules are checked before the opener, so a rule on the opener's own
        # line is not recorded.
        rule = rule_re.search(line)
        if inside and rule:
            inst_hash[cur_inst].get_port_inst(rule.group(1).strip(), rule.group(2).strip())
        opener = opener_re.search(line)
        if opener:
            inside = True
            mod_name = opener.group(1)
            cur_inst = opener.group(2)
            new_module = Module(mod_name)
            inst_hash[cur_inst] = Inst(cur_inst, new_module)
            module_hash[mod_name] = new_module
        if closer_re.search(line):
            inside = False
    return inst_hash, module_hash
pass#}}}

def sys_rtl_file(module_hash, src_rtl_list):#{{{
    '''
    Find, for every module in module_hash, the RTL file that declares it.

    The directories in src_rtl_list are tried in order.  Inside a directory
    the "*.v" files are tried in sorted name order.  The first file holding a
    "module <name>" declaration is stored via Module.get_rtl_file().  If a
    module is found nowhere, exit() is called with an error note.

    Returns module_hash (modified in place).
    '''
    import glob
    import io
    for module in module_hash.keys():
        # Whole-word match: the old `grep 'module foo'` also hit
        # "module foo_bar" and "endmodule foo" and could pick the wrong file.
        # Scanning in Python also avoids building a shell command from
        # unquoted paths taken out of the RTL.
        declare = re.compile(r"\bmodule\s+%s\b" % re.escape(module))
        found = ""
        for fl in src_rtl_list:
            for path in sorted(glob.glob(os.path.join(fl, "*.v"))):
                if not os.path.isfile(path):
                    continue
                with io.open(path, "r", encoding="utf-8", errors="replace") as rtl:
                    if any(declare.search(line) for line in rtl):
                        found = path
                        break
            if found:
                break
        if found:
            module_hash[module].get_rtl_file(found)
    for module in module_hash.keys():
        if module_hash[module].path == "":
            exit("module " + module + " need rtl file!!!")
    return module_hash
pass#}}}

def width_space_ex(matched):#{{{
    '''re.sub callback: pad the first captured group (a "[..]" width) with one space on each side.'''
    return "".join((" ", matched.group(1), " "))
#}}}

def sys_rtl_sig_para(inst_hash, module_hash):#{{{
    '''
    Parse each module's RTL file and attach its declared signals.

    For every module in module_hash, the "module <name>" ... "endmodule"
    section of its file is scanned for input/output/wire/reg declarations.
    Each declared name becomes a Signal registered on the Module, and then
    the Module's input/output tables are rebuilt.  Each Inst is then
    re-pointed at the shared Module object for its module name, and its
    .port(link) rules are split into signal links and parameter overrides.

    Returns (inst_hash, module_hash), both updated in place.
    '''
    # \b after the keywords: without it "input_vld <= x;" was read as an
    # input port "_vld", and "input wire_a" as a wire port named "_a".
    # An optional "signed" is skipped instead of being taken as a name.
    decl_re = re.compile(r"(input|output|wire|reg)\b(\s+wire\b|\s+reg\b)?(?:\s+signed\b)?"
                         r"\s*(\[.*\])?\s*([\s,\w]+)\s*")
    for module in module_hash.keys():
        file_path = module_hash[module].path
        with open(file_path, "r") as rtl:
            handle = rtl.readlines()
        # \b: the section of "module foo" must not start at "module foo_bar".
        rtl_line = head_tail_split(handle, r"^module\s+%s\b" % re.escape(module), r"^endmodule", 1)
        for line in rtl_line:
            line = re.sub(r"(\[.*\])", width_space_ex, line.strip())
            res = decl_re.match(line)
            if not res:
                continue
            sig_type = "wire"
            port = "none"
            width = res.group(3) or ""
            if res.group(1) in ("input", "output"):
                port = res.group(1)
            else:
                sig_type = res.group(1)
            if res.group(2):
                sig_type = res.group(2).strip()
            for sig in res.group(4).split(","):
                name = sig.strip()
                if name != "":
                    s = Signal(name, sig_type, port, width)
                    # Previously only a hard-coded project module was dumped.
                    debug_print(s)
                    module_hash[module].get_sig(s)
    for module in module_hash.keys():
        module_hash[module].split_port()
    for inst in inst_hash.keys():
        module = inst_hash[inst].module.name
        inst_hash[inst].get_module(module_hash[module])
        inst_hash[inst].split_port_para()
    return inst_hash, module_hash
pass#}}}

def gen_inst(inst_hash, module_hash):#{{{
    '''
    Build the port / parameter connection text of every instance.

    For each Inst, every output and then every input of its Module is
    connected to the name produced by the first .port(link) rule that
    rewrites it (see match_rep), or to itself when no rule applies.
    Parameter overrides become the "#( ... )" list.  Declared widths get the
    overridden parameter values substituted in.

    Sets inst.port_list / inst.para_list and returns
    (inst_hash, output_declare, input_declare, wire_declare).  Each dict maps
    a connected net name to its "<width><name>;" declaration text (all
    whitespace removed).  A name that appears in both output_declare and
    input_declare is moved into wire_declare.
    '''
    output_declare = {}   # shared by all instances
    input_declare  = {}
    wire_declare   = {}

    for inst_name in inst_hash.keys():
        inst   = inst_hash[inst_name]
        module = inst.module

        out_ports = []    # (".port(link)", note) per output
        for out_sig in module.output_hash.keys():
            link, note = _apply_link_rules(out_sig, inst.real_port_link)
            out_ports.append((".%s(%s)" % (out_sig, link), note))
            # An empty link leaves the output unconnected: nothing to declare.
            if _is_plain_ident(link):
                width = _subst_para(module.sig_hash[out_sig].width, inst.real_para_link)
                output_declare[link] = re.sub(r"\s+", "", "%s %s;" % (width, link))

        in_ports = []     # (".port(link)", note) per input
        for in_sig in module.input_hash.keys():
            link, note = _apply_link_rules(in_sig, inst.real_port_link)
            in_ports.append((".%s(%s)" % (in_sig, link), note))
            # Constants (1'b0), empty links and expressions are not nets and
            # must not get a declaration (the old code emitted "input[3:0];"
            # for an empty link).
            if _is_plain_ident(link):
                width = _subst_para(module.sig_hash[in_sig].width, inst.real_para_link)
                input_declare[link] = re.sub(r"\s+", "", "%s %s;" % (width, link))

        # Every connection ends with "," except the very last one.  The old
        # code indexed keys()[-1] (a TypeError on Python 3), and it left a
        # trailing comma when the module had no inputs.
        last = len(out_ports) + len(in_ports) - 1
        port_list = ["    //output port inst"]
        for idx, (text, note) in enumerate(out_ports):
            port_list.append("    " + text + ("" if idx == last else ",") + note)
        port_list.append("    //input port inst")
        for idx, (text, note) in enumerate(in_ports, len(out_ports)):
            port_list.append("    " + text + ("" if idx == last else ",") + note)

        paras = list(inst.real_para_link.keys())
        para_list = []
        for idx, para in enumerate(paras):
            comma = "" if idx == len(paras) - 1 else ","
            para_list.append("    .%s(%s)%s" % (para, inst.real_para_link[para], comma))

        inst.get_port_list(port_list)
        inst.get_para_list(para_list)

    # A net driven by one instance and read by another is an internal wire.
    for dec in sorted(output_declare.keys()):
        if dec in input_declare:
            wire_declare[dec] = output_declare.pop(dec)
            input_declare.pop(dec)

    return inst_hash, output_declare, input_declare, wire_declare

def _apply_link_rules(sig, port_link):
    '''Return (link, note) from the first port_link rule that rewrites sig, else (sig, "").'''
    for port in port_link.keys():
        port_inst = port_link[port]
        linked = match_rep(port, port_inst, sig)
        if linked != sig:
            return linked, "//FROM .%s(%s)" % (port, port_inst)
    return sig, ""

def _subst_para(width, para_link):
    '''Replace whole-word parameter names in width with their override values.'''
    for para in para_link.keys():
        value = para_link[para]
        # \b on both sides: parameter "W" must not rewrite the W in "DW"
        # (the old unanchored re.sub turned "[DW-1:0]" into "[D8-1:0]").
        width = re.sub(r"\b(?:%s)\b" % para, lambda m, v=value: v, width)
    return width

def _is_plain_ident(name):
    '''True when name is a simple Verilog identifier that can be declared.'''
    return re.match(r"[A-Za-z_][\w$]*\Z", name) is not None
pass#}}}

def gen_link(rtl_handle, inst_hash, output_declare, input_declare, wire_declare):#{{{
    '''
    Produce the output RTL lines with the GEN_* placeholders expanded.

    - A "//GEN_INPUT" or "//GEN_OUTPUT" line is followed by "input ..." or
      "output ..." declarations for every generated name that rtl_handle
      does not already declare.
    - A "//GEN_WIRE" line is followed by the wire declarations.  When the
      file has no //GEN_INPUT or //GEN_OUTPUT marker, those names are
      declared here as wires instead.
    - "<module> [#(/*GEN_PARA*/)] <inst> (/*GEN_LINK*/)" is commented out
      (tagged GEN_LINK_ADD) and followed by the generated instance; the
      parameter list is included only when /*GEN_PARA*/ is present.
    Every generated region is wrapped in /*GEN_LINK START*/ ... END*/ so
    del_gen_link_code can remove it again.  Also fills inst.inst_with_para
    and inst.inst_without_para.

    Returns the list of output lines (trailing whitespace stripped).
    Exits if a GEN_LINK instance has no LINK_MODULE description.
    '''
    start = "/*GEN_LINK START*/"
    end   = "/*GEN_LINK END*/"
    # \b after the keywords: "input_vld <= x;" is not a declaration, and
    # "input wire_a" declares wire_a rather than "_a".
    decl_re = re.compile(r"^\s*(input|output|wire|reg)\b\s*(?:(?:wire|reg)\b)?\s*(?:signed\b)?"
                         r"\s*(\[.*\])?\s*([\s,\w]+)\s*")
    inst_re = re.compile(r"(\w+)\s*(#\(/\*GEN_PARA\*/\))?\s*(\w+)\s*(\(/\*GEN_LINK\*/\))")

    # Names already declared by hand are not declared again.  The old code
    # used str.split(r"\s*,\s*"), which never splits, so only single-name
    # declarations were recognised and "input a, b;" got re-declared.
    has_dec = set()
    for line in rtl_handle:
        res = decl_re.search(line)
        if res:
            for sig in res.group(3).split(","):
                debug_print("PORT--" + sig)
                has_dec.add(sig.strip())

    for inst in inst_hash.values():
        name = inst.module.name
        body = list(inst.port_list) + [");"]
        if len(inst.para_list) > 0:
            head = [name + " #("] + list(inst.para_list) + [") " + inst.name + "("]
        else:
            head = [name + " " + inst.name + "("]
        inst.get_inst_with_para(head + body)
        inst.get_inst_without_para([name + " " + inst.name + "("] + body)

    # Without a GEN_INPUT / GEN_OUTPUT marker those names become wires.
    input_flag  = any("//GEN_INPUT" in line for line in rtl_handle)
    output_flag = any("//GEN_OUTPUT" in line for line in rtl_handle)

    def decls(keyword, declare, suffix=""):
        # Sorted declarations of names not yet declared by hand.
        return [keyword + " " + declare[n] + suffix for n in sorted(declare) if n not in has_dec]

    tmp_list = []
    for line in rtl_handle:
        line = line.rstrip()
        if "//GEN_INPUT" in line:
            tmp_list += [line, start] + decls("input", input_declare) + [end]
        elif "//GEN_OUTPUT" in line:
            tmp_list += [line, start] + decls("output", output_declare) + [end]
        elif "//GEN_WIRE" in line:
            tmp_list += [line, start] + decls("wire", wire_declare)
            if not input_flag:
                tmp_list += decls("wire", input_declare, " //org = input")
            if not output_flag:
                tmp_list += decls("wire", output_declare, " //org = output")
            tmp_list.append(end)
        else:
            res = inst_re.search(line)
            if not res:
                tmp_list.append(line)
                continue
            inst = res.group(3)
            if inst not in inst_hash:
                # Previously a bare KeyError traceback.
                sys.exit("gen_link: instance %s has no LINK_MODULE comment" % inst)
            tmp_list += ["//" + line + " GEN_LINK_ADD", start]
            if res.group(2):
                tmp_list += inst_hash[inst].inst_with_para
            else:
                tmp_list += inst_hash[inst].inst_without_para
            tmp_list.append(end)

    return tmp_list
pass#}}}

def gen_port(file_handle):#{{{
    input_list  = []
    output_list = []
    tmp_list    = []
    input_list.append("//input port")
    output_list.append("//output port  /*GEN_LINK START*/")
    input_str  = ""
    output_str = ""
    for line in file_handle:
        res = re.search(r"^\s*(input|output)\s*(wire|reg)?\s*(\[.*\])?\s*([\s,\w]+)\s*", line)
        if res:
            #print(line)
            port = res.group(1).strip()
            for sig in res.group(4).split(r"\s*,\s*"):
                name = sig.strip(",")
                name = name.strip()
                #print(port, name)
                if port == "input":
                    input_str  = input_str + name + ","
                else:
                    output_str = output_str + name + ","
                if len(input_str) > 60:
                    input_list.append(input_str)
                    input_str = ""
                if len(output_str) > 60:
                    output_list.append(output_str)
                    output_str = ""
    if input_str != "":
        input_list.append(input_str)
    if output_str != "":
        output_list.append(output_str)
    input_list[-1] = re.sub(r",$", "", input_list[-1])
    input_list.append("/*GEN_LINK END*/")
    
    port_str = r"/*GEN_PORT*/\n"
    for line in output_list:
        port_str = port_str + line + "\n"
    for line in input_list:
        port_str = port_str + line + "\n"
    #print(port_str)

    for line in file_handle:
        line = re.sub(r"/\*GEN_PORT\*/", port_str, line)
        tmp_list.append(line)
    return tmp_list
pass#}}}

def del_gen_link_code(handle):#{{{
    '''
    Undo a previous run of this script.

    Every /*GEN_LINK START*/ ... /*GEN_LINK END*/ region is dropped,
    including the marker lines.  Lines that were commented out as
    "//<line> GEN_LINK_ADD" are restored to <line>.  Returns the remaining
    lines, right-stripped.
    '''
    added_re = re.compile(r"//(.*) GEN_LINK_ADD")
    kept = []
    skipping = False
    for raw in handle:
        text = raw.rstrip()
        tagged = added_re.search(text)
        if tagged:
            text = tagged.group(1)
        if "/*GEN_LINK START*/" in text:
            skipping = True
        if not skipping:
            kept.append(text)
        if "/*GEN_LINK END*/" in text:
            skipping = False
    return kept
#}}}

def main():#{{{
    '''
    Script entry point: read the RTL named by -f and print the result.

    With -d only the previously generated code is removed.  Otherwise the
    old generated code is removed and everything is regenerated.  Nothing is
    printed in -debug mode.
    '''
    global rtl_file
    global del_flag
    source = input_args_proc()

    if del_flag:
        out_rtl = del_gen_link_code(source)
    else:
        source = del_gen_link_code(source)
        src_rtl_list = get_src_rtl(source)
        inst_hash, module_hash = link_module(source)
        module_hash = sys_rtl_file(module_hash, src_rtl_list)
        inst_hash, module_hash = sys_rtl_sig_para(inst_hash, module_hash)
        # Analysis done; generate the instances and declarations.
        inst_hash, output_declare, input_declare, wire_declare = gen_inst(inst_hash, module_hash)
        linked = gen_link(source, inst_hash, output_declare, input_declare, wire_declare)
        out_rtl = gen_port(linked)

    if debug:
        return
    for line in out_rtl:
        print(line)
#}}}

if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()
